from Tools.Scripts.dutree import display

__author__ = 'Amit'

import sys
import re
import os
from serverHandlers import InvivoGenWizardSession
from serverHandlers import NupackSession
import time
from concurrent.futures import ThreadPoolExecutor

###################### Constants ######################

# Trigger: the detection target X is split into domains a, b and c.
aLen = 12
bLen = 14
cLen = 3
detectionTargetLen = sum((aLen, bLen, cLen))

# Silencing Target: split into domains y and z.
yLen = 2
zLen = 19
silencingTargetLen = sum((yLen, zLen))


###################### Main ######################

def run():
    """Run one complete job session (all work happens in JobSession's constructor)."""
    JobSession()

###################### Job Session ######################

class JobSession:
    """One end-to-end design run driven by the command-line input files.

    Usage: ``python scRNA.py trigger_file target_file`` -- ``sys.argv[1]`` is
    the trigger gene file and ``sys.argv[2]`` is the target gene file.
    Constructing a JobSession performs the whole pipeline: parse inputs, pick
    a silencing target, enumerate candidate scRNAs over the trigger gene,
    score them and write the results.
    """

    @staticmethod
    def parseInputFile(inputPathTargetGene, inputPathTriggerGene):
        """Read both input files and return ``(targetGene, triggerGene)``.

        Each file's whole content is returned with leading/trailing whitespace
        stripped (interior newlines are kept).
        """
        with open(inputPathTargetGene) as f:
            targetGene = f.read().strip()
        with open(inputPathTriggerGene) as f:
            triggerGene = f.read().strip()
        return (targetGene, triggerGene)

    def __init__(self, pathToAllResultsFolder = 'All Results', jobTitle = 'run'):
        """Run the full pipeline, exiting with status 1 on unusable input."""
        if (len(sys.argv) < 3):
            print('Not enough arguments.\nRun the program with the paths of the input files:\n> python scRNA.py trigger_file target_file')
            # sys.exit is always available (the builtin exit() depends on the
            # site module) and a usage error should not report success.
            sys.exit(1)

        self.paths = JSPaths(pathToAllResultsFolder, jobTitle)
        (self.targetGene, self.triggerGene) = JobSession.parseInputFile(sys.argv[2], sys.argv[1])
        self.targetList = createListOfTargets(self.targetGene, silencingTargetLen)
        self.targetList = [dna2rna(seq) for seq in self.targetList]
        if not self.targetList:
            print('No silencing targets were found for the target gene.')
            sys.exit(1)
        # The first entry of the returned list is used as the silencing target.
        bestSilencingTarget = rnaFormat(self.targetList[0])
        self.scRNAs = self.runOverTriggerWithTarget(bestSilencingTarget)
        scoring = ScRNAsScoring(self.scRNAs, self.triggerGene, self.paths)
        scoring.writeResults()

    def runOverTriggerWithTarget(self, target):
        """Build one scRNA for every detection-target window of the trigger.

        The window of length ``detectionTargetLen`` slides one base at a time
        over the RNA-formatted trigger gene; each window is split into domains
        a/b/c and combined with the y/z domains of *target*. Windows are
        recorded with 1-based inclusive start/end positions.

        Returns the list of scRNAs (also stored in ``self.scRNAs``).
        """
        self.scRNAs = []
        triggerRNA = rnaFormat(self.triggerGene)
        (y, z) = mRNAsilencingTarget(target)
        for i in range(len(triggerRNA) - detectionTargetLen + 1):
            window = triggerRNA[i:(i + detectionTargetLen)]
            (a, b, c) = mRNAdetectionTarget(window)
            startBase = i + 1
            endBase = i + detectionTargetLen
            self.scRNAs.append(createScRNA(a, b, c, y, z, startBase, endBase))
        return self.scRNAs


###################### scRNA Scoring ######################

class ScRNAsScoring:
    """Analyze a batch of scRNAs concurrently and rank them by score.

    Each scRNA is analyzed twice through a NupackSession (hairpins A and B
    alone, then together with a chopped piece of the trigger gene); the
    resulting tables are wrapped in ``Concentrations`` objects and ranked by
    ``Concentrations.getListOfBestScRNA``.
    """

    def __init__(self, scRNAs, triggerGene, paths):
        """Run all analyses and store the ranked list in ``sortedListOfscRNA``.

        Analyses run on a pool of 10 threads. Progress is reported every 60
        seconds; if no analysis finishes for 4 consecutive checks the wait is
        abandoned and only the finished results are ranked. Exceptions raised
        by individual analyses are printed instead of being silently lost.
        """
        from concurrent.futures import wait

        self.scRNAs = scRNAs
        self.triggerGene = triggerGene
        self.paths = paths
        # Appended to by worker threads (see calculateConc).
        self.concentrations = []

        print('Starting to analyze ' + str(len(scRNAs)) + ' scRNAs.')
        numberOfThreads = 10
        pool = ThreadPoolExecutor(numberOfThreads)
        futures = []
        for scRNA in self.scRNAs:
            futures.append(pool.submit(self.calculateConc, scRNA))
            # Spreads the submissions over ~numberOfThreads seconds in total
            # (presumably to avoid flooding the remote service -- TODO confirm).
            time.sleep(numberOfThreads / len(scRNAs))

        pending = set(futures)
        stalledChecks = 0
        while pending:
            # Returns early once everything is done, otherwise after 60 s.
            done, pending = wait(pending, timeout=60)
            for future in done:
                error = future.exception()
                if error is not None:
                    print('Failed to analyze an scRNA: ' + repr(error))
            stalledChecks = 0 if done else stalledChecks + 1
            if stalledChecks == 4:
                print('No progress for 4 checks; giving up on ' + str(len(pending)) + ' scRNAs.')
                break
            if pending:
                print(str(len(pending)) + ' scRNAs remaining...')
        # Do not block on analyses that are still stuck.
        pool.shutdown(wait=False)

        # Snapshot: abandoned workers may still append later.
        self.sortedListOfscRNA = Concentrations.getListOfBestScRNA(list(self.concentrations))

    def getBestScRNAs(self):
        """Return the ranked list of ``(scRNA, score)`` tuples."""
        return self.sortedListOfscRNA

    def calculateConc(self, scRNA):
        """Analyze *scRNA* with and without the trigger and record the result."""
        conc = \
            Concentrations(
                scRNA,
                self.analyzeScRNAComplexesWithoutTrigger(scRNA),
                self.analyzeScRNAComplexesWithTrigger(scRNA))
        self.concentrations.append(conc)

    def _analyzeComplexes(self, folders, sequences, titles, concentrations):
        """Run one NUPACK multi-species concentration analysis inside *folders*."""
        (scRNAfolder, scRNAzipFolder) = folders
        createDirIfNotExists(scRNAfolder)
        createDirIfNotExists(scRNAzipFolder)
        nupack = NupackSession(scRNAzipFolder, scRNAfolder)
        concentrationTable = nupack.analyzeMultipleSpeciesConcentrations(sequences, titles, concentrations)
        if (not concentrationTable):
            print('Concentration table is empty.')
        return concentrationTable

    def analyzeScRNAComplexesWithoutTrigger(self, scRNA):
        """Analyze hairpins A and B (100 each) without the trigger."""
        return self._analyzeComplexes(
            self.paths.getScRNAFolder(scRNA),
            [str(scRNA.hpA), str(scRNA.hpB)],
            ['A', 'B'],
            [100, 100])

    def analyzeScRNAComplexesWithTrigger(self, scRNA):
        """Analyze A, B and a 150-base piece of the trigger around X (100 each)."""
        (startBase, endBase) = scRNA.getStartEndTrigger()
        choppedTriggerGene = \
            pieceOfSeq(self.triggerGene, startBase, endBase, 150)
        return self._analyzeComplexes(
            self.paths.getScRNATriggerFolder(scRNA),
            [str(scRNA.hpA), str(scRNA.hpB), choppedTriggerGene],
            ['A', 'B', 'Trigger Gene (chopped)'],
            [100, 100, 100])

    def writeResults(self):
        """Print the five best scRNAs and write all ranked scRNAs to a file.

        The file is ``<job folder>/<job id> - Output.txt``; each entry is a
        ``Score: <score>`` line, the scRNA's text and a blank line.
        """
        pathOfResults = '/' + self.paths.id + ' - Output.txt'
        print('========== Results ==========')
        print('*** Run: ' + self.paths.id + '\n')
        # At most five; the ranked list can be shorter than self.scRNAs when
        # analyses failed or were abandoned.
        for entry in self.sortedListOfscRNA[:5]:
            print(str(entry[0]) + '\n')
        # The job folder is otherwise only created as a side effect of the
        # per-scRNA analyses.
        os.makedirs(self.paths.main(), exist_ok=True)
        with open(self.paths.main() + pathOfResults, "w") as f:
            for (scRNA, score) in self.sortedListOfscRNA:
                f.write('Score: ' + str(score) + '\n')
                f.write(str(scRNA))
                f.write('\n\n')
        print('All the results are in the file: ' + self.paths.main()+pathOfResults)
        print('End.')


class Concentrations:
    """Concentrations of the complexes used to score one scRNA.

    Complexes are identified by tuples of strand counts, in the order the
    strands were listed for the analysis: ``(A, B)`` without the trigger and
    ``(A, B, X)`` with it.
    """

    def __init__(self, scRNA, concsScRNAwithoutTrigger, concsScRNAwithTrigger):
        self.scRNA = scRNA
        self.concsScRNAwithoutTrigger = concsScRNAwithoutTrigger
        self.concsScRNAwithTrigger = concsScRNAwithTrigger

        # A*B (without trigger)
        self.concABwithoutTrigger = \
            self.findConcentrationOfComplex(self.concsScRNAwithoutTrigger, ('1','1'))

        # B & X*A (chopped)
        self.concBWithTrigger = \
            self.findConcentrationOfComplex(self.concsScRNAwithTrigger, ('0','1','0'))
        self.concXAWithTrigger = \
            self.findConcentrationOfComplex(self.concsScRNAwithTrigger, ('1','0','1'))
        self.concXAXAWithTrigger = \
            self.findConcentrationOfComplex(self.concsScRNAwithTrigger, ('2','0','2'))

    def getScRNA(self):             return self.scRNA
    def getABnoTrigger(self):       return self.concABwithoutTrigger
    def getBtrigger(self):          return self.concBWithTrigger
    def getXAtrigger(self):         return self.concXAWithTrigger + self.concXAXAWithTrigger

    def findConcentrationOfComplex(self, concTbl, complex, topNumber = 0):
        """Return the concentration of *complex* in *concTbl*, or 0 if absent.

        Rows are compared on ``row[0]`` and the value is taken from
        ``row[2]`` (presumably the table rows produced by the NUPACK session
        -- confirm against NupackSession). A missing/empty table yields 0.
        When *topNumber* is non-zero only the first *topNumber* rows are
        searched (clamped to the table length).
        """
        if not concTbl:
            return 0
        limit = len(concTbl)
        if topNumber:
            limit = min(topNumber, limit)
        for i in range(limit):
            row = concTbl[i]
            if row[0] == complex:
                return row[2]
        return 0

    @staticmethod
    def getListOfBestScRNA(concsList):
        """Rank scRNAs by ``A*B (no trigger) x free B (with trigger)``, best first.

        Prints each score and returns a list of ``(scRNA, score)`` tuples.
        """
        scRNAscoreList = [(c.getScRNA(), c.getABnoTrigger() * c.getBtrigger())
                          for c in concsList]
        scRNAscoreList.sort(key = (lambda x: x[1]), reverse = True)

        for line in scRNAscoreList:
            print(line[1])

        return scRNAscoreList




###################### Paths ######################

class JSPaths:
    """Folder layout for one job session's results.

    Layout (all joined with '/')::

        <results folder>/<yymmddHHMMSS> <title>/
            scRNAs - without trigger/{_zipFiles, Detection Target X = (s-e)}
            scRNAs - with trigger/{_zipFiles, Detection Target X = (s-e)}
    """

    # Kept only for backward compatibility. __init__ gives every instance its
    # own caches; sharing them at class level would make a second session in
    # the same process reuse the first session's folders.
    dictScRNA = dict()
    dictScRNATrigger = dict()

    def __init__(self, pathToAllResultsFolder, jobSessionTitle):
        self.id = time.strftime("%y%m%d%H%M%S") + ' ' + jobSessionTitle         # ID
        self.pathOfDestFolder = pathToAllResultsFolder                          # All Results Folder (Dest of JS)
        self.pathOfJS_Folder = self.pathOfDestFolder + '/' + self.id            # JS Folder

        # All scRNAs
        self.pathOfJS_All_scRNAs_Folder = self.pathOfJS_Folder + '/scRNAs - without trigger'    # All scRNAs Folder
        self.pathOfJS_Zip_Folder = self.pathOfJS_All_scRNAs_Folder + '/_zipFiles'               # All/JS/Zip Folder

        # All scRNAsWithTrigger
        self.pathOfJS_All_scRNAs_Trigger_Folder = self.pathOfJS_Folder + '/scRNAs - with trigger'       # All scRNAs Folder
        self.pathOfJS_Zip_Trigger_Folder = self.pathOfJS_All_scRNAs_Trigger_Folder + '/_zipFiles'       # All/JS/Zip Folder

        # Per-instance caches: str(scRNA) -> (scRNA folder, zip folder).
        self.dictScRNA = {}
        self.dictScRNATrigger = {}

    def main(self):             return self.pathOfJS_Folder
    def scRNAs(self):           return self.pathOfJS_All_scRNAs_Folder
    def scRNAsTrigger(self):    return self.pathOfJS_All_scRNAs_Trigger_Folder
    def scRNAsZip(self):        return self.pathOfJS_Zip_Folder
    def scRNAsTriggerZip(self): return self.pathOfJS_Zip_Trigger_Folder

    @staticmethod
    def _folderTitle(scRNA):
        """Folder name for *scRNA*, built from its detection-target positions."""
        (startIndex, endIndex) = scRNA.getStartEndTrigger()
        return 'Detection Target X = ('+ str(startIndex) +'-'+ str(endIndex) +')'

    def addPathOfScRNA(self, scRNA):
        """Compute, cache and return the no-trigger (folder, zip folder) pair."""
        folders = (self.scRNAs() + '/' + self._folderTitle(scRNA), self.scRNAsZip())
        self.dictScRNA[str(scRNA)] = folders
        return folders

    def addPathOfScRNATrigger(self, scRNA):
        """Compute, cache and return the with-trigger (folder, zip folder) pair."""
        folders = (self.scRNAsTrigger() + '/' + self._folderTitle(scRNA), self.scRNAsTriggerZip())
        self.dictScRNATrigger[str(scRNA)] = folders
        return folders

    def getScRNAFolder(self, scRNA):
        """Return the cached no-trigger folders for *scRNA*, creating the entry if needed."""
        cached = self.dictScRNA.get(str(scRNA))
        if cached is None:
            return self.addPathOfScRNA(scRNA)
        return cached

    def getScRNATriggerFolder(self, scRNA):
        """Return the cached with-trigger folders for *scRNA*, creating the entry if needed."""
        cached = self.dictScRNATrigger.get(str(scRNA))
        if cached is None:
            return self.addPathOfScRNATrigger(scRNA)
        return cached


###################### Strand ######################

class Strand:
    """Base class for a nucleic-acid strand; ``seq`` holds the sequence string."""

    def __init__(self, sequence = ''):
        super().__init__()
        self.seq = sequence

    def __str__(self):
        return self.seq

    def __repr__(self):
        # Debug-friendly form, e.g. HairpinA('GGAU...'), shared by subclasses.
        return type(self).__name__ + '(' + repr(self.seq) + ')'

class HairpinA(Strand):
    """Hairpin A: domain z followed by the reverse complements of c, b and a."""

    def __init__(self, a, b, c, z):
        super().__init__()
        tail = ''.join(reverseComplementRNA(domain) for domain in (c, b, a))
        self.seq = z + tail

class HairpinB(Strand):
    """Hairpin B: z, rc(c), b, c, rc(z), rc(y) concatenated (rc = reverse complement)."""

    def __init__(self, b, c, y, z):
        super().__init__()
        parts = [
            z,
            reverseComplementRNA(c),
            b,
            c,
            reverseComplementRNA(z),
            reverseComplementRNA(y),
        ]
        self.seq = ''.join(parts)

class XmRNA(Strand):
    """Detection target X (a + b + c) plus its 1-based start/end in the trigger."""

    def __init__(self, a, b, c, startBase, endBase):
        super().__init__(a + b + c)
        self.startBase, self.endBase = startBase, endBase

class ScRNA:
    """A candidate scRNA: hairpins A and B plus the detection target X."""

    def __init__(self, hpA, hpB, mRNAX):
        super().__init__()
        self.hpA = hpA
        self.hpB = hpB
        self.mRNAX = mRNAX

    def display(self):
        """Return a three-line 'A: ... / B: ... / X: ... (start - end)' description."""
        target = self.mRNAX
        position = ' (' + str(target.startBase) + ' - ' + str(target.endBase) + ')'
        rows = [
            'A: ' + str(self.hpA),
            'B: ' + str(self.hpB),
            'X: ' + str(target) + position,
        ]
        return '\n'.join(rows)

    def __str__(self):
        return self.display()

    def getStartEndTrigger(self):
        """Return ``(startBase, endBase)`` of the detection target."""
        target = self.mRNAX
        return (target.startBase, target.endBase)


###################################################################

def createListOfTargets(targetGene, silencingTargetLen):
    """Return the candidate silencing-target sequences reported by InvivoGenWizardSession."""
    return InvivoGenWizardSession(targetGene, silencingTargetLen).getTargetSeqList()

def mRNAdetectionTarget(triggerTarget):
    """Split a detection-target window into consecutive domains ``(a, b, c)``."""
    bEnd = aLen + bLen
    cEnd = bEnd + cLen
    return (triggerTarget[:aLen], triggerTarget[aLen:bEnd], triggerTarget[bEnd:cEnd])

def mRNAsilencingTarget(silencingTarget):
    """Split a silencing target into its domains ``(y, z)``."""
    zEnd = yLen + zLen
    return (silencingTarget[:yLen], silencingTarget[yLen:zEnd])

def createScRNA(a, b, c, y, z, startBase, endBase):
    """Assemble an ScRNA from detection domains a/b/c and silencing domains y/z."""
    return ScRNA(
        HairpinA(a, b, c, z),
        HairpinB(b, c, y, z),
        XmRNA(a, b, c, startBase, endBase),
    )


###################### Auxilary ######################

def sequenceFormat(seq):
    """Remove all whitespace from *seq* and upper-case it."""
    compact = re.sub(r'\s+', '', seq)
    return str(compact.upper())

def dna2rna(dnaSeq):
    """Convert a DNA sequence to RNA by replacing every upper-case 'T' with 'U'."""
    return dnaSeq.translate(str.maketrans('T', 'U'))

def rnaFormat(seq):
    """Normalize *seq* to an upper-case RNA string.

    Whitespace is removed and T becomes U. If anything other than A/U/C/G
    remains, 'input error.' is printed and '' is returned.
    """
    rna = sequenceFormat(seq).replace('T', 'U')
    if re.search('[^AUCG]', rna) is None:
        return rna
    print('input error.')
    return ''

def complementaryRNA(sequence):
    """Return the base-by-base RNA complement of *sequence* (A<->U, G<->C).

    Raises KeyError for any character other than upper-case A, U, G or C.
    """
    pairing = {'A': 'U', 'U': 'A', 'G': 'C', 'C': 'G'}
    return ''.join([pairing[base] for base in sequence])

def reverseRNA(sequence):
    """Return *sequence* in reverse order."""
    backwards = slice(None, None, -1)
    return sequence[backwards]

def reverseComplementRNA(sequence):
    """Return the reverse complement of an RNA *sequence*."""
    complement = complementaryRNA(sequence)
    return reverseRNA(complement)

def pieceOfSeq(sequence, startBase, endBase, pieceLen):
    """Return a window of about *pieceLen* bases of *sequence* around a target.

    *startBase*/*endBase* are 1-based and inclusive. The window is centered on
    the target (an odd leftover base goes to the right side) and shifted
    inward when it would run past either end of *sequence*. The result:
    - is exactly *pieceLen* long when the sequence is long enough;
    - is the whole sequence when the sequence is shorter than that;
    - always covers the whole target, even when *pieceLen* is smaller.
    """
    targetLen = endBase - startBase + 1
    extra = max(pieceLen - targetLen, 0)
    leftLen = extra // 2
    rightLen = extra - leftLen
    newStart = (startBase - 1) - leftLen
    newEnd = endBase + rightLen
    if newStart < 0:
        newEnd -= newStart
        newStart = 0
    if newEnd > len(sequence):
        newStart -= newEnd - len(sequence)
        newEnd = len(sequence)
    # Without this clamp a negative start would wrap around in the slice
    # below and return only a tail of the sequence.
    newStart = max(newStart, 0)
    return sequence[newStart:newEnd]

def createDirIfNotExists(path):
    """Create directory *path*, including missing parents, unless it already exists.

    Real failures are raised instead of being silently ignored, for example
    permission denied or *path* already existing as a regular file. A
    directory created concurrently by another thread is tolerated via
    ``exist_ok=True``.
    """
    os.makedirs(path, exist_ok=True)


# Run the whole pipeline as soon as this file is executed or imported
# (intended use: `python scRNA.py trigger_file target_file`).
# NOTE(review): consider `if __name__ == '__main__':` so the module can be
# imported without starting a job.
run()